/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TOR (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Defines container classes and iterators for managing a statically
   sized vector of boolean predicates.
*/
#pragma once

#if defined(__CUDACC_RTC__)
#include <cuda/std/cassert>
#include <cuda/std/cstdint>
#else
#include <assert.h>
#include <stdint.h>
#endif

#include "cutlass/cutlass.h"

#include "cutlass/platform/platform.h"

namespace cutlass {

////////////////////////////////////////////////////////////////////////////////////////////////////

/*!@defgroup predicate_vector_concept Predicate Vector Concept
@{

Implementations of \ref predicate_vector_concept contain an ordered set of
boolean predicates which may be used as conditionals in other device-side
operations. Both random access and iterators offering sequential access are
provided.

@par Predicate Vector
   A \ref predicate_vector_concept satisfies the following expressions
  - <b>at(int idx)</b> - returns the value of the indexed predicate
  - <b>set(int idx, bool value)</b> - sets the value of the indexed predicate
  - <b>begin()</b> - returns a \ref predicate_iterator_concept pointing to the
first predicate

@}
*/

////////////////////////////////////////////////////////////////////////////////////////////////////

/*!@defgroup predicate_iterator_concept Predicate Iterator Concept
@{

Implementations of \ref predicate_iterator_concept enables accessing and
traversing elements of a bit vector.

@par Const Predicate Iterator
  A const \ref predicate_iterator_concept satisfies the following expressions
 - <b>++it</b> increments the iterator to the next predicate
 - <b>*it</b> returns the value of the currently pointed-to predicate

@par Mutable Predicate Iterator
 A \ref predicate_iterator_concept that is non-const <b>also</b> satisfies the
following expressions
 - <b>it.set(bool value)</b> sets the value of the currently pointed-to
predicate

@}
*/

////////////////////////////////////////////////////////////////////////////////////////////////////

/*!@defgroup predicate_tile_adapter Predicate Tile Adapter Concept
@{

Implementations of \ref predicate_tile_adapter provide a mapping between a the
elements of a \ref tile_traits_concept and a \ref predicate_vector_concept.

@par Predicate Tile Adapter
  A \ref predicate_tile_adapter satisfies the following expressions
 - <b>at(int d, int h, int w, int c)</b> - returns the value of a predicate
corresponding to the access (d, h, w, c) within the tile.

@}
*/

////////////////////////////////////////////////////////////////////////////////////////////////////

/// Statically sized array of bits implementing
/// @concept{predicate_vector_concept}.
template <
        /// Number of predicates conatined in predicate vector
        int kPredicates_,
        /// Number of predicates contained in each byte of internal storage
        int kPredicatesPerByte_ = 4,
        /// Location of first predicate within byte of internal storage
        int kPredicateStart_ = 0>
struct PredicateVector {
    /// Number of bits stored by the PredicateVector
    static int const kPredicates = kPredicates_;

    /// Number of bits stored within each byte of the predicate bit vector
    static int const kPredicatesPerByte = kPredicatesPerByte_;

    /// First bit withing each byte containing predicates
    static int const kPredicateStart = kPredicateStart_;

    // Make sure no one tries to put more than 8 bits in a byte :)
    static_assert(kPredicatesPerByte <= 8,
                  "kPredicatesPerByte must fit within an actual byte");
    // Make sure the "offsetted" bits fit in one byte.
    static_assert(kPredicateStart + kPredicatesPerByte <= 8,
                  "The offsetted predicates must fit within an actual byte.");

    /// Storage type of individual elements
    typedef uint32_t Storage;

    /// Number of bytes needed
    static int const kBytes =
            (kPredicates + kPredicatesPerByte - 1) / kPredicatesPerByte;

    /// Number of storage elements needed
    static int const kWordCount =
            (kBytes + sizeof(Storage) - 1) / sizeof(Storage);

private:
    //
    // Data members
    //

    /// Words of bit vector
    Storage storageData[kWordCount];

    //
    // Methods
    //

    /// Computes the word and bit corresponding to a logical predicate index
    CUTLASS_HOST_DEVICE void computeStorageOffset(int& word, int& bit,
                                                  int idx) const {
        CUTLASS_ASSERT(idx < kPredicates);

        int byte = (idx / kPredicatesPerByte);
        int bit_offset = (idx % kPredicatesPerByte);

        word = byte / sizeof(Storage);
        int byte_offset = (byte % sizeof(Storage));

        bit = byte_offset * 8 + bit_offset + kPredicateStart;
    }

    /// Accesses a given word with optional assertions
    CUTLASS_HOST_DEVICE Storage& storage(int word) {
        CUTLASS_ASSERT(word < kWordCount);
        return storageData[word];
    }

    /// Accesses a given word with optional assertions
    CUTLASS_HOST_DEVICE Storage const& storage(int word) const {
        CUTLASS_ASSERT(word < kWordCount);
        return storageData[word];
    }

public:
    //
    // Iterator
    //

    /**
     * @brief An iterator implementing \ref predicate_iterator_concept enabling
     * sequential read and write access to predicates.
     * @concept{predicate_iterator_concept}
     */
    class Iterator {
        /// Reference to PredicateVector instance
        PredicateVector& vec_;

        /// Index into PredicateVector
        int bit_;

    public:
        /// Copy constructor
        CUTLASS_HOST_DEVICE
        Iterator(Iterator const& it) : vec_(it.vec_), bit_(it.bit_) {}

        /// Constructs an iterator from a PredicateVector
        CUTLASS_HOST_DEVICE
        Iterator(PredicateVector& vec, int _start = 0)
                : vec_(vec), bit_(_start) {}

        /// Pre-increment
        CUTLASS_HOST_DEVICE
        Iterator& operator++() {
            ++bit_;
            return *this;
        }

        /// Increment
        CUTLASS_HOST_DEVICE
        Iterator& operator+=(int offset) {
            bit_ += offset;
            return *this;
        }

        /// Pre-decrement
        CUTLASS_HOST_DEVICE
        Iterator& operator--() {
            --bit_;
            return *this;
        }

        /// Decrement
        CUTLASS_HOST_DEVICE
        Iterator& operator-=(int offset) {
            bit_ -= offset;
            return *this;
        }

        /// Post-increment
        CUTLASS_HOST_DEVICE
        Iterator operator++(int) {
            Iterator ret(*this);
            ret.bit_++;
            return ret;
        }

        /// Post-decrement
        CUTLASS_HOST_DEVICE
        Iterator operator--(int) {
            Iterator ret(*this);
            ret.bit_--;
            return ret;
        }

        /// Iterator advances by some amount
        CUTLASS_HOST_DEVICE
        Iterator operator+(int offset) {
            Iterator ret(*this);
            ret.bit_ += offset;
            return ret;
        }

        /// Iterator recedes by some amount
        CUTLASS_HOST_DEVICE
        Iterator operator-(int offset) {
            ConstIterator ret(*this);
            ret.bit_ -= offset;
            return ret;
        }

        /// Returns true if iterators point to the same bit
        CUTLASS_HOST_DEVICE
        bool operator==(Iterator const& it) const { return bit_ == it.bit_; }

        /// Returns false if iterators point to the same bit
        CUTLASS_HOST_DEVICE
        bool operator!=(Iterator const& it) const { return bit_ != it.bit_; }

        /// Gets the bit at the pointed to location
        CUTLASS_HOST_DEVICE
        bool get() { return vec_.at(bit_); }

        /// Gets the bit at the pointed to location
        CUTLASS_HOST_DEVICE
        bool at() const { return vec_.at(bit_); }

        /// Dereferences iterator
        CUTLASS_HOST_DEVICE
        bool operator*() const { return at(); }

        /// Sets the bit at the pointed to location
        CUTLASS_HOST_DEVICE
        void set(bool value = true) { vec_.set(bit_, value); }
    };

    /**
     * @brief An iterator implementing \ref predicate_iterator_concept enabling
     * sequential read and write access to predicates.
     * @concept{predicate_iterator_concept}
     */
    class ConstIterator {
        /// Reference to PredicateVector instance
        PredicateVector const& vec_;

        /// Index into PredicateVector
        int bit_;

    public:
        /// Copy constructor
        CUTLASS_HOST_DEVICE
        ConstIterator(ConstIterator const& it) : vec_(it.vec_), bit_(it.bit_) {}

        /// Constructs an iterator from a PredicateVector
        CUTLASS_HOST_DEVICE
        ConstIterator(PredicateVector const& vec, int _start = 0)
                : vec_(vec), bit_(_start) {}

        /// Pre-increment
        CUTLASS_HOST_DEVICE
        ConstIterator& operator++() {
            ++bit_;
            return *this;
        }

        /// Increment
        CUTLASS_HOST_DEVICE
        ConstIterator& operator+=(int offset) {
            bit_ += offset;
            return *this;
        }

        /// Pre-decrement
        CUTLASS_HOST_DEVICE
        ConstIterator& operator--() {
            --bit_;
            return *this;
        }

        /// Decrement
        CUTLASS_HOST_DEVICE
        ConstIterator& operator-=(int offset) {
            bit_ -= offset;
            return *this;
        }

        /// Post-increment
        CUTLASS_HOST_DEVICE
        ConstIterator operator++(int) {
            ConstIterator ret(*this);
            ret.bit_++;
            return ret;
        }

        /// Post-decrement
        CUTLASS_HOST_DEVICE
        ConstIterator operator--(int) {
            ConstIterator ret(*this);
            ret.bit_--;
            return ret;
        }

        /// Iterator advances by some amount
        CUTLASS_HOST_DEVICE
        ConstIterator operator+(int offset) {
            ConstIterator ret(*this);
            ret.bit_ += offset;
            return ret;
        }

        /// Iterator recedes by some amount
        CUTLASS_HOST_DEVICE
        ConstIterator operator-(int offset) {
            ConstIterator ret(*this);
            ret.bit_ -= offset;
            return ret;
        }

        /// Returns true if iterators point to the same bit
        CUTLASS_HOST_DEVICE
        bool operator==(ConstIterator const& it) const {
            return bit_ == it.bit_;
        }

        /// Returns false if iterators point to the same bit
        CUTLASS_HOST_DEVICE
        bool operator!=(ConstIterator const& it) const {
            return bit_ != it.bit_;
        }

        /// Gets the bit at the pointed to location
        CUTLASS_HOST_DEVICE
        bool get() { return vec_.at(bit_); }

        /// Gets the bit at the pointed to location
        CUTLASS_HOST_DEVICE
        bool at() const { return vec_.at(bit_); }

        /// Dereferences iterator
        CUTLASS_HOST_DEVICE
        bool operator*() const { return at(); }
    };

    /// Iterator that always returns true
    struct TrivialIterator {
        /// Constructor
        CUTLASS_HOST_DEVICE
        TrivialIterator() {}

        /// Copy constructor
        CUTLASS_HOST_DEVICE
        TrivialIterator(Iterator const& it) {}

        /// Constructs an iterator from a PredicateVector
        CUTLASS_HOST_DEVICE
        TrivialIterator(PredicateVector const& _vec) {}

        /// Pre-increment
        CUTLASS_HOST_DEVICE
        TrivialIterator& operator++() { return *this; }

        /// Post-increment
        CUTLASS_HOST_DEVICE
        TrivialIterator operator++(int) { return *this; }

        /// Dereferences iterator
        CUTLASS_HOST_DEVICE
        bool operator*() const { return true; }
    };

public:
    //
    // Methods
    //

    /// Initialize the predicate vector
    CUTLASS_HOST_DEVICE PredicateVector(bool value = true) { fill(value); }

    /// Fills all predicates with a given value
    CUTLASS_HOST_DEVICE void fill(bool value = true) {
        Storage item = (value ? ~Storage(0) : Storage(0));

        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kWordCount; ++i) {
            storage(i) = item;
        }
    }

    /// Clears all predicates
    CUTLASS_HOST_DEVICE void clear() {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kWordCount; ++i) {
            storage(i) = 0;
        }
    }

    /// Sets all predicates to true
    CUTLASS_HOST_DEVICE void enable() {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kWordCount; ++i) {
            storage(i) = ~Storage(0);
        }
    }

    /// Accesses a bit within the predicate vector.
    CUTLASS_HOST_DEVICE bool operator[](int idx) const { return at(idx); }

    /// Accesses a bit within the predicate vector.
    CUTLASS_HOST_DEVICE bool at(int idx) const {
        int bit, word;
        computeStorageOffset(word, bit, idx);

        return ((storage(word) >> bit) & 1);
    }

    /// Set a bit within the predicate vector.
    CUTLASS_HOST_DEVICE void set(int idx, bool value = true) {
        int bit, word;
        computeStorageOffset(word, bit, idx);

        Storage disable_mask = (~(Storage(1) << bit));
        Storage enable_mask = (Storage(value) << bit);

        storage(word) = ((storage(word) & disable_mask) | enable_mask);
    }

    /// Computes the intersection of two identical predicate vectors.
    CUTLASS_HOST_DEVICE PredicateVector& operator&=(
            PredicateVector const& predicates) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kWordCount; ++i) {
            storage(i) = (storage(i) & predicates.storage(i));
        }
        return *this;
    }

    /// Computes the union of two identical predicate vectors.
    CUTLASS_HOST_DEVICE PredicateVector& operator|=(
            PredicateVector const& predicates) {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kWordCount; ++i) {
            storage(i) = (storage(i) | predicates.storage(i));
        }
        return *this;
    }

    /// Returns true if entire predicate array is zero.
    CUTLASS_HOST_DEVICE bool is_zero() const {
        Storage mask(0);
        for (int byte = 0; byte < sizeof(Storage); ++byte) {
            Storage byte_mask =
                    (((1 << kPredicatesPerByte) - 1) << kPredicateStart);
            mask |= (byte_mask << (byte * 8));
        }
        uint32_t result = 0;
        for (int word = 0; word < kWordCount; ++word) {
            result |= storage(word);
        }
        return result == 0;
    }

    /// Returns an iterator to the start of the bit vector
    CUTLASS_DEVICE
    Iterator begin() { return Iterator(*this); }

    /// Returns an iterator
    CUTLASS_DEVICE
    Iterator end() { return Iterator(*this, kPredicates); }

    /// Returns a ConstIterator
    CUTLASS_DEVICE
    ConstIterator const_begin() const { return ConstIterator(*this); }

    /// Returns a ConstIterator
    CUTLASS_DEVICE
    ConstIterator const_end() const {
        return ConstIterator(*this, kPredicates);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace cutlass
